{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import random\n",
    "# import pyepo\n",
    "import torch\n",
    "from torch import nn\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "import sys\n",
    "root_dir = \"../../\"\n",
    "sys.path.append(root_dir)\n",
    "from src.torch_Dijkstra import Dijkstra\n",
    "from src.dys_opt_net import DYS_opt_net\n",
    "from src.utils import node_to_edge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 12x12 Warcraft data: terrain maps, true vertex costs and true paths.\n",
    "# `root_dir` is set in the imports cell; every file name below is relative to this folder.\n",
    "warcraft_data_folder = root_dir + './src/warcraft/warcraft_data/'\n",
    "\n",
    "# Terrain maps: convert to Float Tensor and permute so that the channels are first.\n",
    "tmaps_train, tmaps_val, tmaps_test = (\n",
    "    torch.FloatTensor(np.load(warcraft_data_folder + map_file)).permute(0, 3, 1, 2)\n",
    "    for map_file in (\n",
    "        \"12x12/train_maps.npy\",\n",
    "        \"12x12/val_maps.npy\",\n",
    "        \"12x12/test_maps.npy\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# True vertex weights (costs): convert to Float Tensor.\n",
    "true_cost_train, true_cost_val, true_cost_test = (\n",
    "    torch.FloatTensor(np.load(warcraft_data_folder + cost_file))\n",
    "    for cost_file in (\n",
    "        \"12x12/train_vertex_weights.npy\",\n",
    "        \"12x12/val_vertex_weights.npy\",\n",
    "        \"12x12/test_vertex_weights.npy\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# True paths: the raw NumPy arrays are kept under the `true_path_*` names ...\n",
    "true_path_train, true_path_val, true_path_test = (\n",
    "    np.load(warcraft_data_folder + path_file)\n",
    "    for path_file in (\n",
    "        \"12x12/train_shortest_paths.npy\",\n",
    "        \"12x12/val_shortest_paths.npy\",\n",
    "        \"12x12/test_shortest_paths.npy\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# ... and their Float Tensor versions under the `*_vertex` names.\n",
    "true_path_train_vertex, true_path_val_vertex, true_path_test_vertex = (\n",
    "    torch.FloatTensor(path_array)\n",
    "    for path_array in (true_path_train, true_path_val, true_path_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of edges =  1012\n"
     ]
    }
   ],
   "source": [
    "# Build the grid graph used by the shortest-path LP: vertices, directed edges,\n",
    "# vertex-edge incidence matrix A and the vector b.\n",
    "m = 12 # grid size\n",
    "\n",
    "# Construct vertices. Note each vertex is at the middle of the square.\n",
    "# Row-major order: the vertex of cell (row, col) is vertices[row * m + col].\n",
    "vertices = [(row + 0.5, col + 0.5) for row in range(m) for col in range(m)]\n",
    "num_vertices = len(vertices)\n",
    "\n",
    "# Construct edges between 8-connected neighbours (squared distance 1 or 2).\n",
    "# Only the offsets that lead to a higher vertex index are listed, in increasing\n",
    "# index order (right, down-left, down, down-right); every such pair is emitted in\n",
    "# both directions. This reproduces the edge order of an all-pairs scan restricted\n",
    "# to `i < j`, without comparing all V^2 vertex pairs.\n",
    "forward_offsets = ((0, 1), (1, -1), (1, 0), (1, 1))\n",
    "edge_list = []       # (v1, v2) pairs of vertex coordinates; position = edge index\n",
    "edge_endpoints = []  # matching (tail_index, head_index) pairs, so no list.index() search is needed\n",
    "for row in range(m):\n",
    "    for col in range(m):\n",
    "        tail = row * m + col\n",
    "        for d_row, d_col in forward_offsets:\n",
    "            n_row, n_col = row + d_row, col + d_col\n",
    "            if 0 <= n_row < m and 0 <= n_col < m:\n",
    "                head = n_row * m + n_col\n",
    "                edge_list.append((vertices[tail], vertices[head]))\n",
    "                edge_list.append((vertices[head], vertices[tail])) # have to double edges to allow for travel both ways\n",
    "                edge_endpoints.append((tail, head))\n",
    "                edge_endpoints.append((head, tail))\n",
    "\n",
    "num_edges = len(edge_list)\n",
    "print('number of edges = ', num_edges)\n",
    "\n",
    "## vertex-edge incidence matrix: column j has -1 at the first vertex of edge j\n",
    "## and +1 at its second vertex\n",
    "A = torch.zeros((num_vertices, num_edges))\n",
    "for j, (tail, head) in enumerate(edge_endpoints):\n",
    "    A[tail, j] = -1.\n",
    "    A[head, j] = +1.\n",
    "\n",
    "## Create b vector necessary for LP approach to shortest path:\n",
    "## -1 at the first vertex, +1 at the last vertex, 0 elsewhere\n",
    "b = torch.zeros(num_vertices)\n",
    "b[0] = -1.0\n",
    "b[-1] = 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get paths in edge format as well"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the project helper `node_to_edge` (from src.utils) on each split's vertex-format\n",
    "# paths, passing the shared `edge_list`, to obtain the edge-format paths.\n",
    "true_path_train_edge, true_path_val_edge, true_path_test_edge = (\n",
    "    node_to_edge(vertex_paths, edge_list, four_neighbors=False)\n",
    "    for vertex_paths in (\n",
    "        true_path_train_vertex,\n",
    "        true_path_val_vertex,\n",
    "        true_path_test_vertex,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Tensor Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One TensorDataset per split; each wraps, in this order:\n",
    "# (terrain maps, edge-format paths, vertex-format paths, true costs).\n",
    "train_dataset, val_dataset, test_dataset = (\n",
    "    TensorDataset(*split_tensors)\n",
    "    for split_tensors in (\n",
    "        (tmaps_train, true_path_train_edge, true_path_train_vertex, true_cost_train),\n",
    "        (tmaps_val, true_path_val_edge, true_path_val_vertex, true_cost_val),\n",
    "        (tmaps_test, true_path_test_edge, true_path_test_vertex, true_cost_test),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a dictionary of the PyTorch variables (datasets plus graph data) and save it.\n",
    "# NOTE(review): the dictionary holds TensorDataset objects and a Python list, not only\n",
    "# tensors, so loading it presumably needs torch.load(..., weights_only=False) on\n",
    "# recent PyTorch versions -- confirm against the code that reads this file.\n",
    "state_dict = {\n",
    "    'train_dataset': train_dataset,\n",
    "    'val_dataset': val_dataset,\n",
    "    'test_dataset': test_dataset,\n",
    "    'm': m, # vertex grid size; taken from `m` so it cannot drift from A, b and edge_list\n",
    "    'A': A, # vertex-edge incidence matrix\n",
    "    'b': b, # b vector for the LP approach to shortest path\n",
    "    'num_edges': num_edges,\n",
    "    'edge_list': edge_list,\n",
    "}\n",
    "\n",
    "save_path = warcraft_data_folder + \"Warcraft_training_data12.pth\"\n",
    "\n",
    "# Save the dictionary to a file\n",
    "torch.save(state_dict, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_dataset_v = TensorDataset(tmaps_train, true_path_train, true_cost_train)\n",
    "# true_path_train_edge = node_to_edge(true_path_train, edges, four_neighbors=False)\n",
    "# train_dataset_e = TensorDataset(tmaps_train, true_path_train_edge, true_path_train_vertex, true_cost_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# val_dataset_v = TensorDataset(tmaps_val, true_path_val, true_cost_val)\n",
    "# true_path_val_e = node_to_edge(true_path_val, edges, four_neighbors=False)\n",
    "# val_dataset_e = TensorDataset(tmaps_val, true_path_val_e, true_cost_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_dataset_v = TensorDataset(tmaps_test, true_path_test, true_cost_test)\n",
    "# true_path_test_e = node_to_edge(true_path_test, edges, four_neighbors=False)\n",
    "# test_dataset_e = TensorDataset(tmaps_test, true_path_test_e, true_cost_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Create a dictionary of the PyTorch variables\n",
    "# state_dict = {\n",
    "# 'train_dataset_v': train_dataset_v,\n",
    "# 'train_dataset_e': train_dataset_e,\n",
    "# 'val_dataset_v': val_dataset_v,\n",
    "# 'val_dataset_e': val_dataset_e,\n",
    "# 'test_dataset_v': test_dataset_v,\n",
    "# 'test_dataset_e': test_dataset_e,\n",
    "# 'm': 12, # vertex grid size\n",
    "# 'A': A,\n",
    "# 'b': b,\n",
    "# 'num_edges': num_edges,\n",
    "# 'edge_list': edges\n",
    "# }\n",
    "\n",
    "# save_path = warcraft_data_folder + \"Warcraft_training_data12.pth\"\n",
    "\n",
    "# # Save the dictionary to a file\n",
    "# torch.save(state_dict, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fpo_dys_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
